module Functions_wa
    
    use Globals
    use Toolbox
    use ModuleSaving

contains

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     Initial Guess for V  !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    subroutine InitialVGuess_wa

    ! Builds a starting guess for the value function V and then refreshes the
    ! continuation values by calling ComputeVE.
    !
    ! The guess is the value of receiving, in every period, the utility of a
    ! household that works half of hmax and leaves its asset holdings
    ! unchanged (geometric sum with discount factor beta).
    !
    ! Takes no arguments: S, r, wge, hmax, tauc, sg, B, varphi, beta and tiny
    ! are read from the enclosing scope, and V is overwritten.

    implicit none

    real(8), dimension(NS) :: hours, inc_k, inc_l, net_tax, resources, cons, util

    ! Hours worked: half of the maximum in every state
    hours = 0.5D0*hmax

    ! Capital and labor income implied by the two columns of S
    inc_k = r*S(:,1)
    inc_l = wge*S(:,2)*hours

    ! Net tax bill for that income pair
    net_tax = TAX(inc_l,inc_k)

    ! Resources available, and consumption when next-period assets equal S(:,1)
    resources = S(:,1) + inc_k + inc_l - net_tax
    cons      = (resources - S(:,1))/(1d0+tauc)
    cons      = max(tiny,cons)          ! keep consumption strictly positive

    ! Period utility
    util = ( (cons**(1d0-sg))/(1d0-sg) ) - B*( (hours**(1d0+varphi))/(1d0+varphi) )

    ! Value of receiving util in every period
    V = util/(1.0D0-beta)

    ! Continuation values consistent with this V
    call ComputeVE

    end subroutine InitialVGuess_wa

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     Compute Cont Values  !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    subroutine ComputeVe

    ! Recomputes Ve from V.
    !
    ! For every state, Ve is the Px-weighted sum of V over the next-period
    ! x index, holding the state's asset index (aind) fixed; the weights are
    ! taken from the row of Px selected by the state's current x index (xind).
    ! NOTE(review): Px is presumably the transition matrix of x - confirm in
    ! Globals.
    !
    ! Takes no arguments: V, Px, aind, xind, Na and Nx come from the enclosing
    ! scope, and Ve is overwritten.

    implicit none

    integer :: jx       ! next-period x index
    integer :: shift    ! offset of block jx inside the stacked state vector

    Ve = 0.0D0

    do jx = 1, Nx
        shift = Na*(jx-1)
        Ve    = Ve + Px(xind,jx)*V(aind + shift)
    end do

    end subroutine ComputeVe

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     Tax Function Vector              !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function TAX(yl_in,yk_in) result(T_out)

    ! Net tax payment, vector version.
    !
    ! Arguments:
    !   yl_in : labor income
    !   yk_in : capital income (used element by element with yl_in)
    ! Result:
    !   T_out : tax on labor income + tauk*capital income - transfers,
    !           same length as yl_in
    !
    ! The transfer is a logistic function of total income relative to ymean,
    ! scaled so that it equals mt*ymean when total income is zero.

    implicit none

    real(8), intent(in)               :: yl_in(:), yk_in(:)
    real(8), dimension(size(yl_in,1)) :: T_out
    real(8), dimension(size(yl_in,1)) :: y_all, tax_l, transf, e_tr
    real(8)                           :: norm_tr

    ! Labor plus capital income
    y_all = yl_in + yk_in

    ! Tax on labor income
    tax_l = exp(log(lambda)*((yl_in/ymean)**(-gamma)))*yl_in

    ! Logistic transfer: e_tr varies with income, norm_tr is its value-at-zero scaling
    e_tr    = exp(-rt*(y_all/ymean-omegat))
    norm_tr = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))
    transf  = (mt*ymean)*(e_tr/(1.0D0+e_tr))*(1.0D0/norm_tr)

    ! Net payment
    T_out = tax_l + tauk*yk_in - transf

    end function TAX

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function TAX_sc(yl_in,yk_in) result(T_out)

    ! Net tax payment, scalar version (same formulas as the vector function TAX).
    !
    ! Arguments:
    !   yl_in : labor income
    !   yk_in : capital income
    ! Result:
    !   T_out : tax on labor income + tauk*capital income - transfers
    !
    ! The transfer is a logistic function of total income relative to ymean,
    ! scaled so that it equals mt*ymean when total income is zero.

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    real(8) :: T_out
    real(8) :: y_all, tax_l, transf, e_tr, norm_tr

    ! Labor plus capital income
    y_all = yl_in + yk_in

    ! Tax on labor income
    tax_l = exp(log(lambda)*((yl_in/ymean)**(-gamma)))*yl_in

    ! Logistic transfer: e_tr varies with income, norm_tr is its value-at-zero scaling
    e_tr    = exp(-rt*(y_all/ymean-omegat))
    norm_tr = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))
    transf  = (mt*ymean)*(e_tr/(1.0D0+e_tr))*(1D0/norm_tr)

    ! Net payment
    T_out = tax_l + tauk*yk_in - transf

    end function TAX_sc

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!                 der wrt yl           !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function dTAX_sc(yl_in,yk_in) result(T_out)

    ! First derivative of TAX_sc with respect to labor income.
    !
    ! Arguments:
    !   yl_in : labor income
    !   yk_in : capital income (enters only through total income in the transfer)
    ! Result:
    !   T_out : d(labor income tax)/d(yl) - d(transfers)/d(yl)
    !
    ! The term tauk*yk_in of TAX_sc does not depend on yl and drops out.

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    real(8) :: T_out
    real(8) :: y_all, d_tax_l, d_transf, e_tr, norm_tr

    ! Labor plus capital income
    y_all = yl_in + yk_in

    ! Derivative of the labor income tax
    d_tax_l = exp(log(lambda)*((yl_in/ymean)**(-gamma)))*(1d0-gamma*log(lambda)*((yl_in/ymean)**(-gamma)))

    ! Derivative of the logistic transfer
    e_tr     = exp(-rt*(y_all/ymean-omegat))
    norm_tr  = exp(rt*omegat)/(1.0D0+exp(rt*omegat))
    d_transf = -(rt/ymean)*(mt*ymean)*(e_tr/((1.0D0+e_tr)**2d0))*(1D0/norm_tr)

    ! Marginal net tax
    T_out = d_tax_l - d_transf

    end function dTAX_sc


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!                der2 wrt yl           !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function ddTAX_sc(yl_in,yk_in) result(T_out)

    ! Second derivative of TAX_sc with respect to labor income.
    !
    ! Arguments:
    !   yl_in : labor income
    !   yk_in : capital income (enters only through total income in the transfer)
    ! Result:
    !   T_out : d2(labor income tax)/d(yl)2 - d2(transfers)/d(yl)2

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    real(8) :: inctaxll, transfersll, T_out, yaux, et, chit

    ! total income (sum of labor and capital income)
    yaux = yl_in + yk_in

    ! Tax: derivative of exp(u)*(1-gamma*u), with u = log(lambda)*(yl/ymean)**(-gamma)
    inctaxll = exp(log(lambda)*((yl_in/ymean)**(-gamma)))*gamma*log(lambda)*((yl_in/ymean)**(-gamma-1.0D0))*(1d0/ymean)
    inctaxll = inctaxll*(gamma-1D0+gamma*log(lambda)*((yl_in/ymean)**(-gamma)))

    ! Transfers
    et   = exp(-rt*(yaux/ymean-omegat))
    chit = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))

    ! The squared factor uses an integer exponent: the original base -rt/ymean
    ! is negative whenever rt and ymean are positive, and raising a negative
    ! real to a REAL power is not permitted by the Fortran standard.
    transfersll = ((rt/ymean)**2)*(mt*ymean)*(et/((1.0D0+et)**3d0))*(1d0-et)*(1D0/chit)

    ! total net tax
    T_out = inctaxll - transfersll

    end function ddTAX_sc



    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     VF eval (scalar) given a and h            !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function VF_eval_sc(h_in,a_in,is_in) result(V_out)

    ! Value of a given (hours, asset choice) pair at one state.
    !
    ! Arguments:
    !   h_in  : hours worked
    !   a_in  : asset choice
    !   is_in : state index (row of S)
    ! Result:
    !   V_out : period utility + beta * continuation value, where the
    !           continuation value is Ve interpolated linearly in the asset
    !           direction at a_in

    implicit none

    real(8), intent(in) :: h_in, a_in
    integer, intent(in) :: is_in
    real(8) :: V_out

    real(8) :: inc_k, inc_l, cons, util_now, ve_next
    real(8) :: a_query(1)       ! a_in wrapped as a length-1 vector for the search routine
    integer :: ia_vec(1), ia    ! asset-grid index returned by the search
    integer :: js               ! position of (ia, current x) in the stacked state vector

    ! Capital and labor income
    inc_k = r*S(is_in,1)
    inc_l = wge*S(is_in,2)*h_in

    ! Consumption from the budget constraint, floored at 1D-12
    cons = (S(is_in,1) + inc_k + inc_l - TAX_sc(inc_l,inc_k) - a_in)/(1d0+tauc)
    cons = max(cons, 1D-12)

    ! Period utility
    util_now = ((cons**(1.0D0-sg))/(1.0D0-sg))-B*(((h_in)**(1.0D0+varphi))/(1.0D0+varphi))

    ! Locate a_in on the asset grid
    a_query = a_in
    ia_vec  = vsearchprm_FET(a_query,a_vec,aprm)
    ia      = ia_vec(1)
    js      = ia + Na*(xind(is_in)-1)

    ! Linear interpolation of Ve between grid points ia and ia+1, capped at 1D-8
    ve_next = Ve(js) + (a_in - a_vec(ia)) * ( Ve(js+1) - Ve(js) )/(a_vec(ia+1) - a_vec(ia))
    ve_next = min(ve_next , 1D-8)

    V_out = util_now + beta*ve_next

    end function VF_eval_sc


 

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!    For a given scalar (is_in)               !!!
    !!!    Given a, compute V                       !!!
    !!!    hopt found with grid + Golden            !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function VF_eval_wa_hopt_sc(a_in,is_in,save_dum) result(V_out)

    ! Value at state is_in of choosing assets a_in when hours are chosen to
    ! maximise period utility for that asset choice.
    !
    ! Steps:
    !   1. interpolate Ve linearly in the asset direction at a_in;
    !   2. evaluate period utility and its second derivative in hours (SOC)
    !      at every point of h_vec;
    !   3. refine the best grid point with Newton_uh_sc when the SOC is
    !      negative on the whole grid; fall back to the golden search
    !      SolveBell_uh_sc when Newton is not used or when its result lies
    !      outside the bracket formed by the neighbours of the best grid point;
    !   4. return period utility at the refined hours + beta * continuation.
    !
    ! Arguments:
    !   a_in     : asset choice
    !   is_in    : state index (row of S)
    !   save_dum : 1  -> store the optimal hours and consumption in
    !                    h_pol(is_in) and c_pol(is_in);
    !              >1 -> prints 'Error algorithm';
    !              any other value -> nothing is stored.
    ! Result:
    !   V_out    : value of the asset choice

    implicit none

    ! Local variable declarations
    real(8), intent(in) :: a_in
    integer, intent(in) :: is_in, save_dum

    real(8) :: V_out, uuh(Nh), soc(Nh)
    real(8) :: yk, yl, ylaux, caux, uu, Ve_np, haux, h_in, soc_aux
    real(8) :: h_lo, h_hi       ! bracket around the best grid point

    real(8) :: anp, anp_vec(1), aaux
    integer :: aind_np, aind_np_vec(1), aux(1), aux_int
    integer :: Sind_np, ih

    ! Asset choice
    anp = a_in

    ! Continuation value: linear interpolation of Ve in the asset direction
    anp_vec = anp
    aind_np_vec = vsearchprm_FET(anp_vec,a_vec,aprm)
    aind_np = aind_np_vec(1)
    Sind_np = aind_np + Na*(xind(is_in)-1)

    Ve_np = Ve(Sind_np) + (anp - a_vec(aind_np)) * ( Ve(Sind_np+1) - Ve(Sind_np) )/(a_vec(aind_np+1) - a_vec(aind_np))
    Ve_np = min(Ve_np,1D-8)

    ! Income and assets
    yK    = r*S(is_in,1)    ! capital income
    yLaux = wge*S(is_in,2)  ! income (wage times productivity) per unit of work
    aaux  = S(is_in,1)      ! assets

    ! h on a grid
    do ih = 1, Nh
        h_in = h_vec(ih)    ! pick h
        yL   = yLaux*h_in   ! labor income for that hours choice

        caux = (aaux + yK + yL - TAX_sc(yL,yK) - anp)/(1d0+tauc)     ! consumption for that hours choice

        if (caux>0.0D0) then    ! check that hours choice implies feasible consumption
            ! Integer exponents: 1 - dTAX_sc is negative when the marginal net
            ! tax exceeds one, and a negative real may not be raised to a REAL power.
            soc_aux = -sg*((caux**(-sg-1d0))/(1d0+tauc))*(yLaux**2)*((1d0-dTAX_sc(yL,yK))**2)
            soc_aux = soc_aux - ((caux**(-sg))/(1d0+tauc))*(yLaux**2)*ddTAX_sc(yL,yK)
            soc(ih) = soc_aux - varphi*B*(h_in**(varphi-1d0))
            uuh(ih) = ((caux**(1.0D0-sg))/(1.0D0-sg)) - B*(((h_in)**(1.0D0+varphi))/(1.0D0+varphi))
        else
            ! Infeasible point: fixed penalty utility; SOC set negative so that
            ! it does not rule out Newton on its own
            soc(ih) = -1D0
            uuh(ih) = -1000D0
        endif

    enddo

    ! Pick best labor supply on grid and the bracket formed by its neighbours
    aux = maxloc(uuh)
    aux_int = aux(1)
    h_lo = h_vec(max(aux_int-1,1))
    h_hi = h_vec(min(Nh,aux_int+1))

    ! Sentinel outside the bracket: forces the golden search if Newton is not called
    haux = -2d0

    ! if soc<0 everywhere, use Newton (never started from the first grid point)
    if ((maxval(soc)<0.0D0)) then ! issues when starting at h=0
        haux = Newton_uh_sc(anp,h_vec(max(aux_int,2)),is_in)
    endif

    ! else, use Golden (if Newton ruled out by SOC; Newton stuck; or Newton result outside of bounds)
    if ( .NOT. ( (h_lo <= haux) .AND. (haux <= h_hi) ) ) then ! =-2 if maxval soc >0, =-1 if stuck in Newton
        haux = SolveBell_uh_sc(anp,h_lo,h_hi,is_in)
    endif

    ! Compute utility at the refined hours
    yL   = yLaux*haux                                               ! labor income
    caux = (aaux + yK + yL - TAX_sc(yL,yK) - anp)/(1d0+tauc)        ! consumption
    caux = max(caux,1D-12)                                          ! floor on consumption
    uu   = ((caux**(1.0D0-sg))/(1.0D0-sg)) - B*(((haux)**(1.0D0+varphi))/(1.0D0+varphi))    ! instantaneous utility

    V_out = uu + beta*Ve_np

    ! Save
    if (save_dum==1) then
        h_pol(is_in) = haux
        c_pol(is_in) = caux
    endif
    if (save_dum>1) then
        print *, 'Error algorithm'
    endif

    end function VF_eval_wa_hopt_sc

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  Golden Search [SCALAR] on a                     !!!!!
    !!!!!  h is computed optimally [NESTED GOLDEN]         !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function SolveBell_wa_hopt_sc(xl1,xr1,is_in) result(x_out)

    ! Golden-section search over the asset choice at state is_in.
    !
    ! Maximises VF_eval_wa_hopt_sc(x, is_in, 0) for x in [xl1, xr1]; hours are
    ! optimised inside that function for every trial x.
    !
    ! Arguments:
    !   xl1, xr1 : lower and upper bound of the search interval.  If
    !              xr1 < xl1 the interval collapses to the single point xr1.
    !   is_in    : state index
    ! Result:
    !   x_out    : asset choice selected by the search

    implicit none

    ! Local variable declarations
    real(8), intent(in)  :: xl1, xr1    ! bounds for golden search
    integer, intent(in)  :: is_in       ! state
    real(8) :: x_out                    ! output: asset choice

    ! Named constant instead of an initialised local, which would silently
    ! acquire the SAVE attribute
    real(8), parameter :: tol = 1D-10   ! stop once the step is below tol

    real(8) :: xl0, xr0
    real(8) :: alpha1, alpha2
    real(8) :: d, x1, x2, f1, f2, x1new, x2new, f1new, f2new

    ! Golden-section weights
    alpha1 = 0.5D0*(3.0D0 - sqrt(5.0D0))
    alpha2 = 0.5D0*(sqrt(5.0D0) - 1.0D0)

    xl0 = xl1
    xr0 = xr1

    ! Reversed bounds: collapse the interval onto the upper bound
    if (xr0 < xl0) then
        xl0 = xr0
    endif

    d  = xr0 - xl0

    ! Two interior trial points and their values
    x1 = xl0 + alpha1*d
    x2 = xl0 + alpha2*d

    f1 = VF_eval_wa_hopt_sc(x1,is_in,0)
    f2 = VF_eval_wa_hopt_sc(x2,is_in,0)

    ! d is now the distance between the two trial points
    d  = alpha1*alpha2*d

    f1new = f1
    f2new = f2
    x1new = x1
    x2new = x2

    do while (d > tol)

        ! Adopt the points prepared by the previous pass
        f1 = f1new
        f2 = f2new
        x1 = x1new
        x2 = x2new

        d = d*alpha2

        if (f2 < f1 ) then

            ! Maximum lies to the left: old x1 becomes the right trial point
            x2new = x1
            x1new = x1 - d

            f2new = f1
            f1new = VF_eval_wa_hopt_sc(x1-d,is_in,0)

        else

            ! Maximum lies to the right: old x2 becomes the left trial point
            x2new = x2 + d
            x1new = x2

            f2new = VF_eval_wa_hopt_sc(x2+d,is_in,0)
            f1new = f2

        endif

    end do

    ! The winner is taken from the pair adopted at the start of the last pass;
    ! the pair prepared during that pass is not used.
    if ( f2 < f1) then
        x_out = x1
    else
        x_out = x2
    endif

    end function SolveBell_wa_hopt_sc

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  Golden Search [SCALAR] on h, static opt on U     !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function SolveBell_uh_sc(a_in,xl1,xr1,is_in) result(x_out)

    ! Golden-section search over hours at state is_in for a given asset choice.
    !
    ! Maximises period utility only (the continuation value does not depend on
    ! hours once a_in is fixed).
    !
    ! Arguments:
    !   a_in     : asset choice, held fixed
    !   xl1, xr1 : lower and upper bound for hours.  If xr1 < xl1 the interval
    !              collapses to the single point xr1.
    !   is_in    : state index (row of S)
    ! Result:
    !   x_out    : hours selected by the search

    implicit none

    ! Local variable declarations
    real(8), intent(in)  :: a_in, xl1, xr1  ! given asset choice, bounds for h
    integer, intent(in)  :: is_in           ! state
    real(8) :: x_out                        ! output: labor choice

    ! Named constant instead of an initialised local, which would silently
    ! acquire the SAVE attribute
    real(8), parameter :: tol = 1D-10       ! stop once the step is below tol

    real(8) :: xl0, xr0
    real(8) :: alpha1, alpha2
    real(8) :: d, x1, x2, f1, f2, x1new, x2new, f1new, f2new
    real(8) :: yK, yLaux, aaux

    aaux  = S(is_in,1)          ! assets
    yK    = r*aaux              ! capital income
    yLaux = wge*S(is_in,2)      ! income (wage times productivity) per unit of work

    ! Golden-section weights
    alpha1 = 0.5D0*(3.0D0 - sqrt(5.0D0))
    alpha2 = 0.5D0*(sqrt(5.0D0) - 1.0D0)

    xl0 = xl1
    xr0 = xr1

    ! Reversed bounds: collapse the interval onto the upper bound
    if (xr0 < xl0) then
        xl0 = xr0
    endif

    d  = xr0 - xl0

    ! Two interior trial points and their utilities
    x1 = xl0 + alpha1*d
    x2 = xl0 + alpha2*d

    f1 = period_utility(x1)
    f2 = period_utility(x2)

    ! d is now the distance between the two trial points
    d  = alpha1*alpha2*d

    f1new = f1
    f2new = f2
    x1new = x1
    x2new = x2

    do while (d > tol)

        ! Adopt the points prepared by the previous pass
        f1 = f1new
        f2 = f2new
        x1 = x1new
        x2 = x2new

        d = d*alpha2

        if (f2 < f1 ) then

            ! Maximum lies to the left: old x1 becomes the right trial point
            x2new = x1
            x1new = x1 - d

            f2new = f1
            f1new = period_utility(x1-d)

        else

            ! Maximum lies to the right: old x2 becomes the left trial point
            x2new = x2 + d
            x1new = x2

            f2new = period_utility(x2+d)
            f1new = f2

        endif

    end do

    ! The winner is taken from the pair adopted at the start of the last pass;
    ! the pair prepared during that pass is not used.
    if ( f2 < f1) then
        x_out = x1
    else
        x_out = x2
    endif

    contains

        ! Period utility of working h hours, given the host's a_in, aaux, yK and yLaux.
        function period_utility(h) result(u)

        real(8), intent(in) :: h
        real(8) :: u
        real(8) :: yL, ytot, caux

        yL   = yLaux*h          ! labor income
        ytot = yL + yK          ! total income

        caux = (aaux + ytot - TAX_sc(yL,yK) - a_in)/(1d0+tauc)
        caux = max(caux,1D-12)  ! floor on consumption
        u = ( (caux**(1.0D0-sg))/(1.0D0-sg) ) - B*( ((h)**(1.0D0+varphi))/(1.0D0+varphi) )

        end function period_utility

    end function SolveBell_uh_sc


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  NEWTON [SCALAR] on h, static opt on U            !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function Newton_uh_sc(a_in,h_in,is_in) result(h_out)

    ! Newton iteration on the first-order condition for hours at state is_in,
    ! holding the asset choice fixed.
    !
    ! Arguments:
    !   a_in  : asset choice, held fixed
    !   h_in  : starting point for hours
    !   is_in : state index (row of S)
    ! Result:
    !   h_out : hours at which successive iterates differ by at most tol, or
    !           -1 when the iteration did not converge within maxiter steps or
    !           produced an undefined step.  The caller VF_eval_wa_hopt_sc
    !           treats any value outside its bracket as "use golden search".

    implicit none

    real(8), intent(in)  :: a_in, h_in  ! given asset input, starting point for h
    integer, intent(in)  :: is_in       ! state
    real(8) :: h_out                    ! output: labor choice

    real(8), parameter :: tol = 1D-5    ! tolerance on the change in h
    integer, parameter :: maxiter = 30  ! max number of iterations

    real(8) :: h_init, h_end, err
    integer :: iter

    real(8) :: caux, yK, yLaux, yL, aaux
    real(8) :: foc_out, soc_out
    real(8) :: keep                     ! 1 - marginal net tax on labor income

    err  = 1.0D0    ! initialize error
    iter = 0        ! initialize iteration counter

    aaux  = S(is_in,1)      ! assets
    yK    = r*aaux          ! capital income
    yLaux = wge*S(is_in,2)  ! income (wage times productivity) per hour worked

    h_init = h_in           ! starting point for hours

    do while ((err > tol) .and. (iter<maxiter))

        yL = yLaux*h_init   ! labor income

        caux = (aaux + yL + yK - TAX_sc(yL,yK) - a_in)/(1d0+tauc)      ! consumption
        caux = max(caux,1D-12)                                          ! floor on consumption

        keep = 1d0 - dTAX_sc(yL,yK)

        ! FOC
        foc_out = ((caux**(-sg))/(1d0+tauc)  ) * yLaux * keep
        foc_out = foc_out - B* (h_init**(varphi))

        ! SOC (integer exponents: keep can be negative, and a negative real
        ! may not be raised to a REAL power)
        soc_out = -sg*((caux**(-sg-1d0))/(1d0+tauc)) * (yLaux**2) * (keep**2)
        soc_out = soc_out - ((caux**(-sg))/(1d0+tauc)) * (yLaux**2) * ddTAX_sc(yL,yK)
        soc_out = soc_out - varphi*B* (h_init**(varphi-1d0))

        ! Newton step
        h_end = h_init - foc_out/soc_out
        err = abs(h_init-h_end)
        iter = iter + 1
        h_init = h_end

    enddo

    ! Success is judged by the error itself, so a solution reached on one of
    ! the last permitted iterations is kept.  A NaN error (undefined step)
    ! fails this test and yields the sentinel, which triggers golden search.
    if (err <= tol) then
        h_out = h_init
    else
        h_out = -1.0D0
    endif

    end function Newton_uh_sc

end module Functions_wa
